{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6bd615f3-0e28-4ea8-8878-49fd62d833cd",
   "metadata": {},
   "source": [
    "# Select the selection algorithms\n",
    "\n",
    "The idea here is that we can simplify the decision logic, reduce the binary size\n",
    "and speed up the compilation time by only including a subset of selection algorithms.\n",
    "We're aiming to get algorithms that perform well in different situations, and complement\n",
    "each other - so to do this, we're iteratively removing the worst performing algorithm,\n",
    "after which algorithms are re-evaluated on their speedups relative to the remaining\n",
    "algorithms. This gets us a minimum spanning set of selection algorithms that performs\n",
    "well over diverse inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fbf08b4e-7a91-4c4b-8320-ded040d3f827",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>key_type</th>\n",
       "      <th>index_type</th>\n",
       "      <th>algo</th>\n",
       "      <th>row</th>\n",
       "      <th>col</th>\n",
       "      <th>k</th>\n",
       "      <th>use_index_input</th>\n",
       "      <th>use_memory_pool</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix8bits</td>\n",
       "      <td>1</td>\n",
       "      <td>1024</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bits</td>\n",
       "      <td>1</td>\n",
       "      <td>1024</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bitsExtraPass</td>\n",
       "      <td>1</td>\n",
       "      <td>1024</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kWarpImmediate</td>\n",
       "      <td>1</td>\n",
       "      <td>1024</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kWarpFiltered</td>\n",
       "      <td>1</td>\n",
       "      <td>1024</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21405</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bits</td>\n",
       "      <td>7</td>\n",
       "      <td>1254</td>\n",
       "      <td>7696</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21406</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bitsExtraPass</td>\n",
       "      <td>7</td>\n",
       "      <td>1254</td>\n",
       "      <td>7696</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21407</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix8bits</td>\n",
       "      <td>7</td>\n",
       "      <td>2189</td>\n",
       "      <td>7960</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21408</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bits</td>\n",
       "      <td>7</td>\n",
       "      <td>2189</td>\n",
       "      <td>7960</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21409</th>\n",
       "      <td>float</td>\n",
       "      <td>int64_t</td>\n",
       "      <td>kRadix11bitsExtraPass</td>\n",
       "      <td>7</td>\n",
       "      <td>2189</td>\n",
       "      <td>7960</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.000019</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>21410 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      key_type index_type                   algo  row   col     k  \\\n",
       "0        float    int64_t            kRadix8bits    1  1024     1   \n",
       "1        float    int64_t           kRadix11bits    1  1024     1   \n",
       "2        float    int64_t  kRadix11bitsExtraPass    1  1024     1   \n",
       "3        float    int64_t         kWarpImmediate    1  1024     1   \n",
       "4        float    int64_t          kWarpFiltered    1  1024     1   \n",
       "...        ...        ...                    ...  ...   ...   ...   \n",
       "21405    float    int64_t           kRadix11bits    7  1254  7696   \n",
       "21406    float    int64_t  kRadix11bitsExtraPass    7  1254  7696   \n",
       "21407    float    int64_t            kRadix8bits    7  2189  7960   \n",
       "21408    float    int64_t           kRadix11bits    7  2189  7960   \n",
       "21409    float    int64_t  kRadix11bitsExtraPass    7  2189  7960   \n",
       "\n",
       "       use_index_input  use_memory_pool      time  \n",
       "0                    0                1  0.000017  \n",
       "1                    0                1  0.000012  \n",
       "2                    0                1  0.000012  \n",
       "3                    0                1  0.000009  \n",
       "4                    0                1  0.000010  \n",
       "...                ...              ...       ...  \n",
       "21405                0                1  0.000015  \n",
       "21406                0                1  0.000015  \n",
       "21407                0                1  0.000030  \n",
       "21408                0                1  0.000019  \n",
       "21409                0                1  0.000019  \n",
       "\n",
       "[21410 rows x 9 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from select_k_dataset import load_dataframe, get_dataset\n",
    "\n",
    "df = load_dataframe(\"select_k_times.json\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c0d9fed5-35c3-4b0b-987a-973e93e0c59c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def rank_algos(df, use_relative_speedup=False):\n",
    "    _, y, weights = get_dataset(df)\n",
    "    times = Counter()\n",
    "    for algo, speedup in zip(y, weights):\n",
    "        times[algo] += speedup if use_relative_speedup else 1\n",
    "    return sorted(times.items(), key=lambda x:-x[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "56c5dd8e-6f44-4ef3-b825-1d5a5d6698a2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('kWarpDistributedShm', 1157),\n",
       " ('kRadix11bits', 1064),\n",
       " ('kWarpImmediate', 447),\n",
       " ('kRadix11bitsExtraPass', 369),\n",
       " ('kFaissBlockSelect', 302),\n",
       " ('kWarpDistributed', 42),\n",
       " ('kWarpFiltered', 23),\n",
       " ('kRadix8bits', 4)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show the number of times each algorithm is fastest for a given k/# of rows/# of cols / dtype / memory pool etc\n",
    "rank_algos(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ec63f794-0bdf-4afe-92a8-3fa8ab7a1648",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('kWarpDistributedShm', 1157),\n",
       " ('kRadix11bits', 1064),\n",
       " ('kWarpImmediate', 447),\n",
       " ('kRadix11bitsExtraPass', 369),\n",
       " ('kFaissBlockSelect', 302),\n",
       " ('kWarpDistributed', 42),\n",
       " ('kWarpFiltered', 23),\n",
       " ('kRadix8bits', 4)]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# kRadix8bits seems to have a performance issue with 64 bit index types, it is one\n",
    "# of the worst performing algorithms for 64bit indices, but one of the top 3 for 32 bit\n",
    "rank_algos(df[df.index_type == \"int64_t\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9b7de19f-ddb6-4fa6-b423-db384428d701",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rank_algos(df[df.index_type == \"uint32_t\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "bc0a10ea-652b-4822-8587-514c8f0348c3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "selected {'kWarpImmediate', 'kRadix11bitsExtraPass', 'kRadix11bits', 'kWarpDistributedShm'}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('kWarpDistributedShm', 1266),\n",
       " ('kRadix11bits', 1156),\n",
       " ('kWarpImmediate', 577),\n",
       " ('kRadix11bitsExtraPass', 409)]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# do an algorithm selection pass, repeatedly remove the lowest performing algorithm\n",
    "#\n",
    "# The idea here is that we can simplify the decision logic, reduce the binary size\n",
    "# and speed up the compilation time by only including a subset of selection algorithms.\n",
    "# we're aiming to get algorithms that perform well in different situations, and complement\n",
    "# each other - so to do this, we're iteratively removing the worst performing algorithm,\n",
    "# after which algorithms are re-evaluated on their speedups relative to the remaining\n",
    "# algorithms. This gets us a minimum spanning set of selection algorithms that performs\n",
    "# well over diverse inputs.\n",
    "#\n",
    "# note: the lowest performing algorithm here might actually be pretty good, but\n",
    "# just not provide much benefit over another similar algorithm. \n",
    "# As an example, kWarpDistributed  is an excellent selection algorithm, but in testing \n",
    "# kWarpDistributedShm is slightly faster than it in situations where it does well, \n",
    "# meaning that it gets removed early on in this loop\n",
    "current = df[df.use_memory_pool == True]\n",
    "algos = set(df.algo)\n",
    "\n",
    "# we're arbitrarily getting this down to 3 selection algorithms\n",
    "while len(algos) > 4:\n",
    "    times = rank_algos(current, use_relative_speedup=False)\n",
    "    algo, speedup = times[-1]\n",
    "    algos.remove(algo)\n",
    "    current = df[df.algo.isin(algos)]\n",
    "\n",
    "print(\"selected\", algos)\n",
    "rank_algos(current)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "028ebbb1-5289-4104-a13c-494c74742087",
   "metadata": {},
   "outputs": [],
   "source": [
    "# experimenting with different subsets of index type / dtype / use memory seems\n",
    "# to pretty consistently show that kRadix11bits / kWarpDistributedShm / kFaissBlockSelect\n",
    "# all get selected here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18fb1217-388e-4530-a4f2-eb6bb0dc6e1f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
